{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FARM: Use your own dataset\n",
    "    \n",
    "In Tutorial 1 you already learned about the major building blocks.\n",
    "In this tutorial, you will see how to use FARM with your own dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's start by adjusting the working directory so that it is the root of the repository\n",
    "# Run this cell only once: every execution moves one directory further up.\n",
    "\n",
    "import os\n",
    "os.chdir('../')\n",
    "print(\"Current working directory is {}\".format(os.getcwd()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1) How a Processor works"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Architecture\n",
    "The Processor converts a <b>raw input (e.g. a file) into a PyTorch dataset</b>.   \n",
    "To use your own dataset, we need to adjust this Processor.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/deepset-ai/FARM/master/docs/img/data_silo_no_bg.jpg\" width=\"400\" height=\"400\" align=\"left\"/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "<br/><br/>\n",
    "\n",
    "\n",
    "### Main Conversion Stages \n",
    "1. Read from file / raw input \n",
    "2. Create samples\n",
    "3. Featurize samples\n",
    "4. Create PyTorch Dataset\n",
    "\n",
    "### Functions to implement\n",
    "1. `file_to_dicts()`\n",
    "2. `_dict_to_samples()`\n",
    "3. `_sample_to_features()`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: TextClassificationProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from farm.data_handler.processor import *\n",
    "from farm.data_handler.samples import Sample\n",
    "from farm.modeling.tokenization import BertTokenizer\n",
    "#from farm.modeling.tokenization import tokenize_with_metadata\n",
    "\n",
    "import os\n",
    "\n",
    "class TextClassificationProcessor(Processor):\n",
    "    \"\"\"\n",
    "    Used to handle the text classification datasets that come in tabular format (CSV, TSV, etc.)\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        tokenizer,\n",
    "        max_seq_len,\n",
    "        data_dir,\n",
    "        label_list=None,\n",
    "        metric=None,\n",
    "        train_filename=\"train.tsv\",\n",
    "        dev_filename=None,\n",
    "        test_filename=\"test.tsv\",\n",
    "        dev_split=0.1,\n",
    "        delimiter=\"\\t\",\n",
    "        quote_char=\"'\",\n",
    "        skiprows=None,\n",
    "        label_column_name=\"label\",\n",
    "        multilabel=False,\n",
    "        header=0,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        #TODO If an arg is misspelt, e.g. metrics, it will be swallowed silently by kwargs\n",
    "\n",
    "        # Custom processor attributes\n",
    "        self.delimiter = delimiter\n",
    "        self.quote_char = quote_char\n",
    "        self.skiprows = skiprows\n",
    "        self.header = header\n",
    "\n",
    "        super(TextClassificationProcessor, self).__init__(\n",
    "            tokenizer=tokenizer,\n",
    "            max_seq_len=max_seq_len,\n",
    "            train_filename=train_filename,\n",
    "            dev_filename=dev_filename,\n",
    "            test_filename=test_filename,\n",
    "            dev_split=dev_split,\n",
    "            data_dir=data_dir,\n",
    "            tasks={},\n",
    "        )\n",
    "        #TODO raise info when no task is added due to missing \"metric\" or \"label_list\" arg\n",
    "        if metric and label_list:\n",
    "            if multilabel:\n",
    "                task_type = \"multilabel_classification\"\n",
    "            else:\n",
    "                task_type = \"classification\"\n",
    "            self.add_task(name=\"text_classification\",\n",
    "                          metric=metric,\n",
    "                          label_list=label_list,\n",
    "                          label_column_name=label_column_name,\n",
    "                          task_type=task_type)\n",
    "\n",
    "    def file_to_dicts(self, file: str) -> [dict]:\n",
    "        column_mapping = {task[\"label_column_name\"]: task[\"label_name\"] for task in self.tasks.values()}\n",
    "        dicts = read_tsv(\n",
    "            filename=file,\n",
    "            delimiter=self.delimiter,\n",
    "            skiprows=self.skiprows,\n",
    "            quotechar=self.quote_char,\n",
    "            rename_columns=column_mapping,\n",
    "            header=self.header\n",
    "            )\n",
    "\n",
    "        return dicts\n",
    "\n",
    "    def _dict_to_samples(self, dict: dict, **kwargs) -> [Sample]:\n",
    "        # this tokenization also stores offsets\n",
    "        tokenized = tokenize_with_metadata(dict[\"text\"], self.tokenizer, self.max_seq_len)\n",
    "        return [Sample(id=None, clear_text=dict, tokenized=tokenized)]\n",
    "\n",
    "    def _sample_to_features(self, sample) -> dict:\n",
    "        features = sample_to_features_text(\n",
    "            sample=sample,\n",
    "            tasks=self.tasks,\n",
    "            max_seq_len=self.max_seq_len,\n",
    "            tokenizer=self.tokenizer,\n",
    "        )\n",
    "        return features\n",
    "      \n",
    "      \n",
    "# Helper\n",
    "def read_tsv(filename, rename_columns, quotechar='\"', delimiter=\"\\t\", skiprows=None, header=0):\n",
    "    \"\"\"Reads a tab separated value file. Tries to download the data if filename is not found\"\"\"\n",
    "    \n",
    "    # get remote dataset if needed\n",
    "    if not (os.path.exists(filename)):\n",
    "        logger.info(f\" Couldn't find {filename} locally. Trying to download ...\")\n",
    "        _download_extract_downstream_data(filename)\n",
    "    \n",
    "    # read file into df\n",
    "    df = pd.read_csv(\n",
    "        filename,\n",
    "        sep=delimiter,\n",
    "        encoding=\"utf-8\",\n",
    "        quotechar=quotechar,\n",
    "        dtype=str,\n",
    "        skiprows=skiprows,\n",
    "        header=header\n",
    "    )\n",
    "\n",
    "    # let's rename our target columns to the default names FARM expects: \n",
    "    # \"text\": contains the text\n",
    "    # \"text_classification_label\": contains a label for text classification\n",
    "    columns = [\"text\"] + list(rename_columns.keys())\n",
    "    df = df[columns]\n",
    "    for source_column, label_name in rename_columns.items():\n",
    "        df[label_name] = df[source_column]\n",
    "        df.drop(columns=[source_column], inplace=True)\n",
    "    \n",
    "    if \"unused\" in df.columns:\n",
    "        df.drop(columns=[\"unused\"], inplace=True)\n",
    "    raw_dict = df.to_dict(orient=\"records\")\n",
    "    return raw_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The default format is: \n",
    "# - tab separated\n",
    "# - column \"text\"\n",
    "# - column \"label\" \n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame({\"text\": [\"The concerts supercaliphractisch was great!\", \"I hate people ignoring climate change.\"],\n",
    "                  \"label\": [\"positive\",\"negative\"]\n",
    "                  })\n",
    "print(df)\n",
    "df.to_csv(\"train.tsv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(\n",
    "    pretrained_model_name_or_path=\"bert-base-uncased\")\n",
    "\n",
    "processor = TextClassificationProcessor(data_dir = \"\", \n",
    "                                        tokenizer=tokenizer,\n",
    "                                        max_seq_len=64,\n",
    "                                        label_list=[\"positive\",\"negative\"],\n",
    "                                        label_column_name=\"label\",\n",
    "                                        metric=\"acc\",\n",
    "                                       )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  1. One File -> Dictionary(ies) with \"raw data\"\n",
    "dicts = processor.file_to_dicts(file=\"train.tsv\")\n",
    "print(dicts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  2. One Dictionary -> Sample(s) \n",
    "#     (Sample = \"clear text\" model input + meta information) \n",
    "samples = processor._dict_to_samples(dict=dicts[0])\n",
    "# print each attribute of sample\n",
    "print(samples[0].clear_text)\n",
    "print(samples[0].tokenized)\n",
    "print(samples[0].features)\n",
    "print(\"----------------------------------\\n\\n\\n\")\n",
    "# or in a nicer, formatted style\n",
    "print(samples[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. One Sample -> Features\n",
    "#    (Features = \"vectorized\" model input)\n",
    "features = processor._sample_to_features(samples[0])\n",
    "print(features[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2) Hands-On: Adjust it to your dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 1: Use an existing Processor\n",
    "\n",
    "This works if you have:\n",
    "- standard tasks\n",
    "- common file formats \n",
    "\n",
    "**Example: Text classification on a TSV file with multiple columns**\n",
    "\n",
    "Dataset: GermEval18 (hate speech detection)  \n",
    "Format: TSV  \n",
    "Columns: `text coarse_label fine_label`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download dataset\n",
    "from farm.data_handler import utils\n",
    "utils._download_extract_downstream_data(\"germeval18/train.tsv\")\n",
    "!head -n 10 germeval18/train.tsv\n",
    "\n",
    "# TODO: Initialize a processor for the above file by passing the right arguments\n",
    "\n",
    "processor = TextClassificationProcessor(tokenizer=tokenizer,\n",
    "                                        max_seq_len=128,\n",
    "                                        data_dir=\"germeval18\",\n",
    "                                        train_filename=\"train.tsv\",\n",
    "                                        label_list=[\"OTHER\",\"OFFENSE\"],\n",
    "                                        metric=\"acc\",\n",
    "                                        label_column_name=\"coarse_label\"\n",
    "                                        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it\n",
    "dicts = processor.file_to_dicts(file=\"germeval18/train.tsv\")\n",
    "print(dicts[0])\n",
    "assert dicts[0] == {'text': '@corinnamilborn Liebe Corinna, wir würden dich gerne als Moderatorin für uns gewinnen! Wärst du begeisterbar?', 'text_classification_label': 'OTHER'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task 2: Build your own Processor\n",
    "This works best for:\n",
    "- custom input files\n",
    "- special preprocessing steps\n",
    "- advanced multitask learning \n",
    "\n",
    "**Example: Text classification with JSON as input file** \n",
    "\n",
    "Dataset: [100k Yelp reviews](https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/yelp_reviews_100k.json) ([full dataset](https://www.yelp.com/dataset/download), [documentation](https://www.yelp.com/dataset/documentation/main))\n",
    "\n",
    "Format: \n",
    "\n",
    "``` \n",
    "{\n",
    "...\n",
    "    // integer, star rating\n",
    "    \"stars\": 4,\n",
    "\n",
    "    // string, the review itself\n",
    "    \"text\": \"Great place to hang out after work: the prices are decent, and the ambience is fun. It's a bit loud, but very lively. The staff is friendly, and the food is good. They have a good selection of drinks.\",\n",
    "...\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download dataset\n",
    "!wget https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/yelp_reviews_100k.json\n",
    "!head -5 yelp_reviews_100k.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# TODO: Create a new Processor class and override the function that reads from the file\n",
    "# The dicts created should look like this to comply with the default TextClassificationProcessor:\n",
    "# {'text': 'Total bill for this horrible service? ...',\n",
    "#  'text_classification_label': '1'}\n",
    "\n",
    "\n",
    "class CustomTextClassificationProcessor(TextClassificationProcessor):\n",
    "  \n",
    "    # we need to override this function from the parent class\n",
    "    def file_to_dicts(self, file: str) -> [dict]:\n",
    "      # read into df\n",
    "      df = pd.read_json(file, lines=True)\n",
    "      # rename\n",
    "      df[\"text_classification_label\"] = df[\"stars\"].astype(str)\n",
    "      # drop unused\n",
    "      columns = [\"text_classification_label\",\"text\"]\n",
    "      df = df[columns]\n",
    "      # convert to dicts\n",
    "      dicts = df.to_dict(orient=\"records\")\n",
    "      return dicts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = CustomTextClassificationProcessor(tokenizer=tokenizer,\n",
    "                                              max_seq_len=128,\n",
    "                                              data_dir=\"\",\n",
    "                                              label_list=[\"1\",\"2\",\"3\",\"4\",\"5\"],\n",
    "                                              metric=\"acc\",\n",
    "                                              )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it\n",
    "\n",
    "dicts = processor.file_to_dicts(file=\"yelp_reviews_100k.json\")\n",
    "print(dicts[0])\n",
    "\n",
    "assert dicts[0] == {'text_classification_label': '1', 'text': 'Total bill for this horrible service? Over $8Gs. These crooks actually had the nerve to charge us $69 for 3 pills. I checked online the pills can be had for 19 cents EACH! Avoid Hospital ERs at all costs.'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
